{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "## import torchtext\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from torchtext.vocab import Vectors\n",
    "import torchtext\n",
    "from tqdm import tqdm_notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = torchtext.data.Field(include_lengths = False)\n",
    "label = torchtext.data.Field(sequential=False)\n",
    "train, val, test = torchtext.datasets.SST.splits(text, label, filter_pred=lambda ex: ex.label != 'neutral')\n",
    "text.build_vocab(train)\n",
    "label.build_vocab(train)\n",
    "train_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits((train, val, test), batch_size=10, device=-1, repeat = False)\n",
    "url = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.simple.vec'\n",
    "text.vocab.load_vectors(vectors=Vectors('wiki.simple.vec', url=url))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels, batch_size):\n",
    "        super(CNN, self).__init__()\n",
    "        self.embeddings = nn.Embedding(text.vocab.vectors.size()[0], text.vocab.vectors.size()[1])\n",
    "        self.embeddings.weight.data.copy_(text.vocab.vectors)\n",
    "        self.convs = nn.ModuleList([nn.Conv1d(in_channels = in_channels, out_channels = in_channels, kernel_size = n) for n in (2,3,4)])\n",
    "        self.dropout_train, self.dropout_test = nn.Dropout(p = 0.5), nn.Dropout(p = 0)\n",
    "        self.linear = nn.Linear(in_features=in_channels, out_features=out_channels, bias = True)\n",
    "    \n",
    "    def forward(self, x, train = True):\n",
    "        embedded = self.embeddings(x)\n",
    "        embedded = embedded.transpose(1, 2)\n",
    "        embedded = embedded.transpose(0, 2)\n",
    "        concatted_features = torch.cat([conv(embedded) for conv in self.convs if embedded.size(2) >= conv.kernel_size[0]], dim = 2)\n",
    "        activated_features = nn.functional.relu(concatted_features)\n",
    "        pooled = nn.functional.max_pool1d(activated_features, activated_features.size(2)).squeeze(2)\n",
    "        dropped = self.dropout_train(pooled) if train else self.dropout_test(pooled)\n",
    "        output = self.linear(dropped)\n",
    "        logits = nn.functional.log_softmax(output, dim = 1)\n",
    "        return logits\n",
    "\n",
    "    def predict(self, x):\n",
    "        logits = self.forward(x, train = False)\n",
    "        return logits.max(1)[1] + 1\n",
    "    \n",
    "    def train(self, train_iter, val_iter, num_epochs, learning_rate = 1e-3):\n",
    "        criterion = nn.NLLLoss()\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)\n",
    "        loss_vec = []\n",
    "        \n",
    "        for epoch in tqdm_notebook(range(1, num_epochs+1)):\n",
    "            epoch_loss = 0\n",
    "            for batch in train_iter:\n",
    "                x = batch.text\n",
    "                y = batch.label\n",
    "                \n",
    "                optimizer.zero_grad()\n",
    "                \n",
    "                y_p = self.forward(x)\n",
    "                \n",
    "                loss = criterion(y_p, y-1)\n",
    "                loss.backward()\n",
    "                \n",
    "                optimizer.step()\n",
    "                epoch_loss += loss.data[0]\n",
    "                \n",
    "            self.model = model\n",
    "            \n",
    "            loss_vec.append(epoch_loss / len(train_iter))\n",
    "            if epoch % 1 == 0:\n",
    "                acc = self.validate(val_iter)\n",
    "                print('Epoch {} loss: {} | acc: {}'.format(epoch, loss_vec[epoch-1], acc))\n",
    "                self.model = model\n",
    "        \n",
    "        plt.plot(range(len(loss_vec)), loss_vec)\n",
    "        plt.xlabel('Epoch')\n",
    "        plt.ylabel('Loss')\n",
    "        plt.show()\n",
    "        print('\\nModel trained.\\n')\n",
    "        self.loss_vec = loss_vec\n",
    "        self.model = model\n",
    "\n",
    "    def test(self, test_iter):\n",
    "        \"All models should be able to be run with following command.\"\n",
    "        upload, trues = [], []\n",
    "        # Update: for kaggle the bucket iterator needs to have batch_size 10\n",
    "        for batch in test_iter:\n",
    "            # Your prediction data here (don't cheat!)\n",
    "            x, y = batch.text, batch.label\n",
    "            probs = self.predict(x)[:len(y)]\n",
    "            upload += list(probs.data)\n",
    "            trues += list(y.data)\n",
    "            \n",
    "        correct = sum([1 if i == j else 0 for i, j in zip(upload, trues)])\n",
    "        accuracy = correct / len(trues)\n",
    "        print('Testset Accuracy:', accuracy)\n",
    "\n",
    "        with open(\"predictions.txt\", \"w\") as f:\n",
    "            for u in upload:\n",
    "                f.write(str(u) + \"\\n\")\n",
    "                \n",
    "    def validate(self, val_iter):\n",
    "        y_p, y_t, correct = [], [], 0\n",
    "        for batch in val_iter:\n",
    "            x, y = batch.text, batch.label\n",
    "            probs = self.model.predict(x)[:len(y)]\n",
    "            y_p += list(probs.data)\n",
    "            y_t += list(y.data)\n",
    "            \n",
    "        correct = sum([1 if i == j else 0 for i, j in zip(y_p, y_t)])\n",
    "        accuracy = correct / len(y_p)\n",
    "        return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a17574a929af4f219f22aca1e3412bd2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sob/Desktop/cs287/homework1/env/lib/python3.6/site-packages/ipykernel_launcher.py:33: DeprecationWarning: generator 'Iterator.__iter__' raised StopIteration\n",
      "/Users/sob/Desktop/cs287/homework1/env/lib/python3.6/site-packages/ipykernel_launcher.py:79: DeprecationWarning: generator 'Iterator.__iter__' raised StopIteration\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 loss: 0.5709313771154495 | acc: 0.7603211009174312\n",
      "Epoch 2 loss: 0.29662336955772917 | acc: 0.786697247706422\n",
      "Epoch 3 loss: 0.10009956448352385 | acc: 0.7775229357798165\n",
      "Epoch 4 loss: 0.03353736516747484 | acc: 0.7729357798165137\n",
      "Epoch 5 loss: 0.01669091764996418 | acc: 0.7775229357798165\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAELCAYAAAA2mZrgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAH8tJREFUeJzt3Xt8VPW57/HPkyuEayIhIBdBQTERRRgRsbo9tVC0fUFboYJStfXeemxPe9q6t3uf07pPL9ae1q2lUrxVrRRv3T3o9oJV66UiEvCCCReRi4ICAcI1QBJ4zh8zxCEmZAJZs+byfb9eeTlrzS8zj0sn31nrN89vzN0REREByAm7ABERSR0KBRERaaJQEBGRJgoFERFpolAQEZEmCgUREWkSaCiY2QQzW25mK83splbGfN3Mqs2sysxmB1mPiIgcngXVp2BmucAKYBywDlgITHP36rgxQ4FHgc+7e62Z9Xb3TYEUJCIibQryTGE0sNLdV7l7PTAHmNRszNXADHevBVAgiIiEK8hQ6Ad8FLe9LrYv3onAiWb2DzN7w8wmBFiPiIi0IS8Fnn8ocB7QH3jFzIa7+7b4QWZ2DXANQJcuXUYNGzYs2XWKiKS1RYsWbXb30rbGBRkK64EBcdv9Y/virQMWuHsDsNrMVhANiYXxg9x9FjALIBKJeGVlZWBFi4hkIjNbm8i4IC8fLQSGmtlgMysApgJzm435K9GzBMysF9HLSasCrElERA4jsFBw90bgBuA5YCnwqLtXmdktZjYxNuw5YIuZVQMvAT909y1B1SQiIocX2EdSg6LLRyIi7Wdmi9w90tY4dTSLiEgThYKIiDRRKIiISBOFgoiINMmaUFhVs4tbn11Guk2si4gkU9aEwovLNnHX3z/g3tdWh12KiEjKyppQuPJzg/liRRm/eGYZC1apFUJEpCVZEwpmxm1TTmNgSRE3/PktNu3YG3ZJIiIpJ2tCAaB7p3xmTh/Frr2NfGf2Yhr2Hwi7JBGRlJJVoQBwUp9u/PKi4SxcU8svn1kWdjkiIikl60IBYNKIflwxdhD3vraap979OOxyRERSRlaGAsC/XHgyIwf25EePv8vKTTvDLkdEJCVkbSgU5OXw+0tHUVSQy7UPLWLXvsawSxIRCV3WhgJAnx6duGPa6azevJsfP/6uGttEJOtldSgAjD2hFz+aMIz/WvKJGttEJOtlfSgAXHvu8WpsExFBoQCosU1E5CCFQowa20REFAqHUGObiGQ7hUIzamwTkWymUGiBGttEJFspFFqgxjYRyVYKhVaosU1EspFC4TDU2CYi2Uah0Ib4xrY3V28NuxwRkUApFNoQ39j2ndmL1dgmIhlNoZAANbaJSLYINBTMbIKZLTezlWZ2Uwv3X2FmNWb2duznqiDrORpqbBORbJAX1AObWS4wAxgHrAMWmtlcd69uNvQRd78hqDo60qQR/Vi8tpZ7X1vN6QN78uVTjw27JBGRDhXkmcJoYKW7r3L3emAOMCnA50uKm79UrsY2EclYQYZCP+CjuO11sX3NXWRm75rZ42Y2IMB6OoQa20Qkk4U90fwkMMjdTwWeBx5oaZCZXWNmlWZWWVNTk9QCW6LGNhHJVEGGwnog/p1//9i+Ju6+xd33xTbvAUa19EDuPsvdI+4eKS0tDaTY9lJjm4hkoiBDYSEw1MwGm1kBMBWYGz/AzPrGbU4ElgZYT4dTY5uIZJrAQsHdG4EbgOeI/rF/1N2rzOwWM5sYG3ajmVWZ2TvAjcAVQdUTBDW2iUimsXS7Hh6JRLyysjLsMg6xfMNOvjLjH5zSrzuzrx5Dfm7YUzUiIocys0XuHmlrnP56dQA1tolIplAodJBJI/px+VnH6RvbRCStKRQ6kBrbRCTdKRQ6UEFeDjMuHanGNhFJWwqFDta3R2c1tolI2lIoBECNbSKSrhQKAVFjm4ikI4VCQNTYJiLpSKEQIH1jm4ikG4VCwNTYJiLpRKGQBGpsE5F0oVBIEjW2iUg6UCgkiRrbRCQdKBSSSI1tIpLqFApJpsY2EUllCoUQqLFNRFKVQiEEamwTkVSlUAiJGttEJBUpFEKkxjYRSTUKhZCpsU1EUolCIQWosU1EUoVCIQUcbGzrnK/GNhEJl0IhRfTt0Zk7L1Fjm4iES6GQQtTYJiJhUyikGDW2iUiYFAopRo1tIhImhUIKUmObiIQl0FAwswlmttzMVprZTYcZd5GZuZlFgqwnnaixTUTCEFgomFkuMAO4ACgHpplZeQvjugHfBRYEVUu6UmObiCRbkGcKo4GV7r7K3euBOcCkFsb9O3AroIvnLVBjm4gkU5Ch0A/4KG57XWxfEzMbCQxw9/863AOZ2TVmVmlmlTU1NR1faQpTY5uIJFNoE81mlgP8BvhBW2PdfZa7R9w9UlpaGnxxKUaNbSKSLEGGwnpgQNx2/9i+g7oBpwB/N7M1wBhgriabW6bGNhFJhiBDYSEw1MwGm1kBMBWYe/BOd9/u7r3cfZC7DwLeACa6e2WANaU1NbaJSNACCwV3bwRuAJ4DlgKPunuVmd1iZhODet5MpsY2EQmapdv16Ugk4pWV2X0ysXzDTr4y4x8M79eDh68+k/xc9SCKyOGZ2SJ3b/PyvP6apKGDjW1vrtmqxjYR6VAKhTSlxjYRCYJCIY2psU1EOppCIY2psU1EOppCIc317dGZO6epsU1EOoZCIQOMHaLGNhHpGAqFDKHGNhHpCAqFDKHGNhHpCAqFDBL/jW03zH5L39gmIu2mUMgwamwTkaOhUMhAamwTkSOlUMhQamwTkSOhUMhQamwTkSOhUMhgamwTkfZSKGQ4NbaJSHsoFLKAGttEJFEKhSygxjYRSZRCIUuosU1EEqFQyCLxjW23qrFNRFqgUMgyBxvb7lFjm4i0QKGQhdTYJiKtUShkITW2iUhrFApZSo1tItKShELBzE4ws8LY7fPM7EYz6xlsaRK0sUN68cMvqrFNRD6V6JnCE8B+MxsCzAIGALMDq0qS5rp/UmObiHwq0VA44O6NwFeBO939h0Df4MqSZFFjm4jESzQUGsxsGnA58FRsX35bv2RmE8xsuZmtNLObWrj/OjNbYmZvm9lrZlaeeOnSUdTYJiIHJRoK3wTOAn7m7qvNbDDw0OF+wcxygRnABUA5MK2FP/qz3X24u48AfgX8pl3VS4dRY5uIAOQlMsjdq4EbAcysGOjm7re28WujgZXuvir2e3OASUB13OPuiBvfBdBHYEI0aUQ/Fq+t5Z7XVjNiYE++fOqxYZckIkmW6KeP/m5m3c2sBFgM3G1mbb2r7wd8FLe9Lrav+WN/x8w+IHqmcGNiZUtQ1Ngmkt0SvXzUI/au/mvAg+5+JvCFjijA3We4+wnAj4F/bWmMmV1jZpVmVllTU9MRTyutaN7YtnV3fdgliUgSJRoKeWbWF/g6n040t2U90Y+uHtQ/tq81c4CvtHSHu89y94i7R0pLSxN8ejlSfXt05neXjGRd7R6mzHydj7ftCbskEUmSREPhFuA54AN3X2hmxwPvt/E7C4GhZjbYzAqAqcDc+AFmNjRu80sJPKYkyVknHMOD3xrNph37mHzX66zctCvskkQkCRIKBXd/zN1PdffrY9ur3P2iNn6nEbiBaJgsBR519yozu8XMJsaG3WBmVWb2NvB9oh95lRRx5vHHMOfaMdTvP8CUma/zzkfbwi5JRAJmiax5Y2b9gTuBs2O7XgW+6+7rAqytRZFIxCsrK5P9tFlt9ebdfOPeBdTurmfWZRHOHtIr7JJEpJ3MbJG7R9oal+jlo/uJXvo5NvbzZGyfZIHBvbrwxPVj6V9cxDfvX8gzSz4JuyQRCUiioVDq7ve7e2Ps54+AZnyzSFn3Tjxy7RhO6ded78xezJ/f/DDskkQkAImGwhYzm25mubGf6cCWIAuT1NOzqIA/XXUm5wwt5Z//soS7/v6BltwWyTCJhsK3iH4cdQPwCTAZuCKgmiSFFRXkcfdlESaediy3PruMnz+9VMEgkkESXeZiLTAxfp+ZfQ+4PYiiJLUV5OVw+8UjKC7K5+5XV1Nb18AvvzacvFx9Z5NIujuaV/H3O6wKSTs5OcZPJlbwvS8M5fFF67juT4vZ27A/7LJE5CgdTShYh1UhacnM+N4XTuSWSRW8sGwjl9/3Jjv2NoRdlogchaMJBV1IFgAuO2sQt188gkVra5n6hzeo2bkv7JJE5AgdNhTMbKeZ7WjhZyfRfgURILrs9j2XR1i1eRdTZr7OR1vrwi5JRI7AYUPB3bu5e/cWfrq5e0KT1JI9zjupNw9fNYbaugYmz3yd5Ru09LZIutHHRaRDjTqumEevPQt3+Pof5rNobW3YJYlIOygUpMOd1KcbT1w/luKifKbfs4CXV+g7METShUJBAjGgpIjHrhvL4F5duOqBhcx95+OwSxKRBCgUJDCl3QqZc+0YTh9YzHfnvMVD89eEXZKItEGhIIHq3imfB781mvOH9ebf/l8V//G397UshkgKUyhI4Drl5zJz+iguGtmf3/5tBT99spoDBxQMIqlIHyuVpMjLzeG2yadSXJTPPa+tpraunl9POY18rZckklIUCpI0OTnGzV86mZKuBfzq2eVs39PAXZeOonNBbtiliUiM3qZJUpkZ3z5vCL/42nBeWVHD9HsXsL1O6yWJpAqFgoRi2uiBzLhkJEvWbefrf5jPxh17wy5JRFAoSIguGN6X+795Butq65g883XWbN4ddkkiWU+hIKE6e0gvZl89hl17G5k8cz5VH28PuySRrKZQkNCdNqAnj103lvxcY+of3uDN1VvDLkkkaykUJCUM6d2VJ64fS+/uhXzj3gW8sHRj2CWJZCWFgqSMY3t25rHrxjKsTzeueWgRTyxaF3ZJIllHoSAppaRLAQ9fPYYxx5fwg8fe4d7XVoddkkhWUShIyulamMd9V5zBBaf04d+fqubXzy3XekkiSRJoKJjZBDNbbmYrzeymFu7/vplVm9m7ZvaCmR0XZD2SPgrzcvndJSOZNnoAv3tpJTf/9T32a70kkcAFtsyFmeUCM4BxwDpgoZnNdffquGFvARF3rzOz64FfARcHVZOkl9wc4+dfHU5xUQG///sHbKur57cXj6AwT8tiiAQlyDOF0cBKd1/l7vXAHGBS/AB3f8ndD37D+xtA/wDrkTRkZvxowjBuvvBknl6ygSv/WMnufY1hlyWSsYIMhX7AR3Hb62L7WnMl8EyA9Ugau/rc4/n1lNOYv2oLl9yzgK2768MuSSQjpcREs5lNByLAba3cf42ZVZpZZU2Nvu83W00e1Z+Z00ex9JMdTJn5Oh9v2xN2SSIZJ8hQWA8MiNvuH9t3CDP7AnAzMNHd97X0QO4+y90j7h4pLS0NpFhJD+PKy3jwW6PZtGMfk+96nQ9qdoVdkkhGCTIUFgJDzWywmRUAU4G58QPM7HTgD0QDYVOAtUgGGXP8Mfz5mjHU7z/AlJnzWbJO6yWJdJTAQsHdG4EbgOeApcCj7l5lZreY2cTYsNuArsBjZva2mc1t5eFEDnFKvx48dt1YigpymTprPq9/sDnskkQygqVbU1AkEvHKysqwy5AUsXHHXi67901Wb97NHdNGMOGUvmGXJJKSzGyRu0faGpcSE80iR6qseyceuXYMp/TrzrcfXswjCz8MuySRtKZQkLTXs6iAP111JucMLeXHTyxh5ssfhF2SSNpSKEhGKCrI4+7LIkw87Vh++cwyfv70Uq2XJHIEAlvmQiTZCvJyuP3iERQX5TPrlVXU7q7nF18bTl6u3vuIJEqhIBklJ8f4ycQKirsUcPvf3mfbngbunHY6nfK1XpJIIvQWSjKOmfG9L5zITydW8Hz1Ri6/70127G0IuyyRtKBQkIx1+dhB/MfUESxaW8u0WW+weVeLDfMiEkehIBlt0oh+3H15hA9qdjFl5nw+2lrX9i+JZDGFgmS8/3ZSbx6+6ky27NrHlJnzWbFxZ9gliaQshYJkhVHHlfDodWdxwJ0pM+ez+MPasEsSSUkKBckaw/p054nrx9KzKJ9L717Ayyu0DLtIcwoFySoDSop4/LqxDOrVhaseWMiT73wcdkkiKUWhIFmntFshj1w7htMHFnPjnLd46I21YZckkjIUCpKVunfK58Fvjeb8Yb35t7++xx0vvK9lMURQKEgW65Sfy8zpo7hoZH9+8/wKfvpkNQcOKBgku2mZC8lqebk53Db5VIqL8rnntdVsq6vntimnka/1kiRLKRQk6+XkGDd/6WRKuhbwq2eXs31PA7+/dBSdC7RekmQfvR0SIbpe0rfPG8Ivvjacl1fU8I17F7C9TuslSfZRKIjEmTZ6IDMuGcm767Zz8az5bNqxN+ySRJJKoSDSzAXD+3LfFWfw4dY6Js+cz9otu8MuSSRpFAoiLfjc0F7MvnoMO/c2cNFd86n+eEfYJYkkhUJBpBUjBvTksevOIj/XuHjWfBau2Rp2SSKBUyiIHMaQ3t14/PqxlHYrZPo9C3hx2cawSxIJlEJBpA39enbmsWvP4qQ+3bj6wUX8ZfG6sEsSCYxCQSQBx3QtZPbVYzhzcAnff/Qd7nttddgliQRCoSCSoK6Fedz/zTOYUNGHW56q5v/OW671kiTjKBRE2qEwL5cZl45k2ugB3PniSm7+63vs13pJkkECDQUzm2Bmy81spZnd1ML955rZYjNrNLPJQdYi0lFyc4yff3U43z7vBGYv+JAb//wW+xr3h12WSIcIbO0jM8sFZgDjgHXAQjOb6+7VccM+BK4A/mdQdYgEwcz40YRhFBcV8LOnl7JjbwMzp4+iS6GWE5P0FuSZwmhgpbuvcvd6YA4wKX6Au69x93eBAwHWIRKYq889nl9POY3XP9jCJXe/wQtLN7K3QWcNkr6CfFvTD/gobnsdcOaRPJCZXQNcAzBw4MCjr0ykA00e1Z8enfP5waNvc+UDlRQV5HLu0FLGV5Tx+WG96VlUEHaJIglLi3Ndd58FzAKIRCKa1ZOUM668jMp/HceC1VuYV7WRedUbeLZqA7k5xpmDSxhfXsa4ij7069k57FJFDivIUFgPDIjb7h/bJ5KRCvJyOGdoKecMLeWnEytYsn4786o3MK9qIz95spqfPFnNKf26M768D+MryjiprBtmFnbZIoewoD5nbWZ5wArgfKJhsBC4xN2rWhj7R+Apd3+8rceNRCJeWVnZwdWKBGtVzS6er97IvOqNLP6wFncYUNI5GhDlZUQGlZCbo4CQ4JjZInePtDkuyOYbM7sQuB3IBe5z95+Z2S1ApbvPNbMzgP8EioG9wAZ3rzjcYyoUJN3V7NzHC0ujAfHays3UNx6gpEsB5w/rzfiKPpwztBed8vWtb9KxUiIUgqBQkEyya18jr6yoYV7VBl5YtomdexvpnJ/LOUN7Mb6iD+cP601xF01Uy9FLNBTSYqJZJFN1LczjwuF9uXB4Xxr2H2DBqq08X72BebFLTbk5xhmDihlf3odx5WUMKCkKu2TJcDpTEElB7s5763c0TVQv37gTgPK+3RlfUca48jLK+3bXRLUkTJePRDLIms27YxPVG6hcG52o7tezM+Mryhhf3oczBhWTl6ulzKR1CgWRDLV51z5eXLqJedUbeOX96ER1z6J8zh9WxviKMs4dWkrnAk1Uy6EUCiJZYPe+Rl59v4Z5VRt5Ydkmtu9poDDWLzG+oozzh/XmmK6FYZcpKUATzSJZoEthHhNO6cuEU6IT1QtXb2Ve9Uaer97I35ZuJMcgMijaUT2+vA8Dj9FEtRyezhREMpC7U/XxjuinmKo2sGxDdKJ6WJ9ujK+INsxVHKuJ6myiy0ci0uTDLXXRTzJVb6RyzVYOxCaqx5WXMb68jDMGl5CvieqMplAQkRZt3V3f1FH9yooa9jUeoEfn/FhHdRnnnlhKUYGuLGcahYKItKmuvpFX398cm6jeyLa66ET154b0ik5Un1xGL01UZwRNNItIm4oK8vhiRR++WNGHxv0HWLimluerN/JcbNkNsyVEjvu0o3pQry5hlywB05mCiHyGu7P0k51NHdXVn+wA4MSyrk1Lfw/v10MT1WlEl49EpMN8tLWuqaP6zdXRieq+PTrFJqr7cObxmqhOdQoFEQlE7e56XlwW7ah+eUUNexsO0K1THucP68248j7800mldC3UlelUo1AQkcDtqd/Pays3M69qA39bupHaugYKcnM4e8gx0aW/T+5N726dwi5T0ESziCRB54JcxpVHV21t3H+ARWtjE9XVG3jpL0swg5EDi6Md1RV9GKyJ6pSnMwUR6XDuzvKNO5lXFZ2HeG99dKJ6SO+uTQFxar8e5OgrSJNGl49EJGWs37aH56uiHdULVm9l/wGnV9dCBpZ0pqRLISVd8inuUkBJUQHFXQo4pkvBIdvdO+Xpk05HSaEgIilpW109Ly3fxKsrNrNx51627m6gdnc9W3fXU7//QIu/k5djcSGRT0mXguhPLDRKuhRQXFTw6f4uBfqe62Y0pyAiKalnUQFfPb0/Xz29/yH73Z3d9fubAmJrXf2nt3fXU1sX++fuBpZv2EltXQO1dfW09r62c35uU0BEA+XTs5GSrp8NlOKifH1REQoFEUkRZkbXwjy6FuYl/F3U+w84O/Y0sLUuLjxigbJ1V1yw1DWwZvNutu6uZ9e+xlYfr0fn/FhI5DcLlOy5rKVQEJG0lRu7rFTcpYATShP7nX2N+9lW19AUIFsOOQv5dHv9tr28t35Huy9rFRfFhUcaXtZSKIhIVinMy6Wsey5l3RPrn0j0stbW3fXtuqwVDZHCQy5rHXI2EtJlLYWCiMhhHM1lreZnIUd7Wet/jDuRiacd21H/ai1SKIiIdLD4y1qJir+s9dnJ9ehlrZKixB/vSCkURERSQHsvawUl0AtVZjbBzJab2Uozu6mF+wvN7JHY/QvMbFCQ9YiIyOEFFgpmlgvMAC4AyoFpZlbebNiVQK27DwF+C9waVD0iItK2IM8URgMr3X2Vu9cDc4BJzcZMAh6I3X4cON8y7UO/IiJpJMhQ6Ad8FLe9LravxTHu3ghsB44JsCYRETmMtOjpNrNrzKzSzCpramrCLkdEJGMFGQrrgQFx2/1j+1ocY2Z5QA9gS/MHcvdZ7h5x90hpaYJtiyIi0m5BhsJCYKiZDTazAmAqMLfZmLnA5bHbk4EXPd2WbRURySCB9Sm4e6OZ3QA8B+QC97l7lZndAlS6+1zgXuAhM1sJbCUaHCIiEpK0+z4FM6sB1h7hr/cCNndgOR1FdbWP6mq/VK1NdbXP0dR1nLu3ef097ULhaJhZZSJfMpFsqqt9VFf7pWptqqt9klFXWnz6SEREkkOhICIiTbItFGaFXUArVFf7qK72S9XaVFf7BF5XVs0piIjI4WXbmYKIiBxGRoZCqi7ZnUBdV5hZjZm9Hfu5Kkl13Wdmm8zsvVbuNzO7I1b3u2Y2MkXqOs/Mtscdr/+VhJoGmNlLZlZtZlVm9t0WxiT9eCVYVxjHq5OZvWlm78Tq+mkLY5L+ekywrlBej7HnzjWzt8zsqRbuC/Z4uXtG/RBtlPsAOB4oAN4BypuN+TYwM3Z7KvBIitR1BfC7EI7ZucBI4L1W7r8QeAYwYAywIEXqOg94KsnHqi8wMna7G7Cihf+OST9eCdYVxvEyoGvsdj6wABjTbEwYr8dE6grl9Rh77u8Ds1v67xX08crEM4VUXbI7kbpC4e6vEO0ob80k4EGPegPoaWZ9U6CupHP3T9x9cez2TmApn139N+nHK8G6ki52DHbFNvNjP80nMpP+ekywrlCYWX/gS8A9rQwJ9HhlYiik6pLdidQFcFHsksPjZjaghfvDkGjtYTgrdgngGTOrSOYTx07bTyf6LjNeqMfrMHVBCMcrdinkbWAT8Ly7t3q8kvh6TKQuCOf1eDvwI+BAK/cHerwyMRTS2ZPAIHc/FXieT98NSMsWE23dPw24E/hrsp7YzLoCTwDfc/cdyXretrRRVyjHy933u/sIoisljzazU5LxvG1JoK6kvx7N7MvAJndfFPRztSYTQ6HDluxOdl3uvsXd98U27wFGBVxTohI5pknn7jsOXgJw96eBfDPrFfTzmlk+0T+8D7v7X1oYEsrxaquusI5X3PNvA14CJjS7K4zXY5t1hfR6PBuYaGZriF5i/ryZ/anZmECPVyaGQqou2d1mXc2uO08kel04FcwFLot9qmYMsN3dPwm7KDPrc/BaqpmNJvr/c6B/TGLPdy+w1N1/08qwpB+vROoK6XiVmlnP2O3OwDhgWbNhSX89JlJXGK9Hd/9nd+/v7oOI/o140d2nNxsW6PEKbOnssHiKLtmdYF03mtlEoDFW1xVB1wVgZn8m+smUXma2DvjfRCfecPeZwNNEP1GzEqgDvpkidU0GrjezRmAPMDUJ4X428A1gSex6NMC/AAPj6grjeCVSVxjHqy/wgJnlEg2hR939qbBfjwnWFcrrsSXJPF7qaBYRkSaZePlIRESOkEJBRESaKBRERKSJQkFERJooFEREpIlCQaQZM9sftzLm29bCirZH8diDrJVVX0VSQcb1KYh0gD2x5Q9Eso7OFEQSZGZrzOxXZrYkthb/kNj+QWb2YmzhtBfMbGBsf5mZ/WdsAbp3zGxs7KFyzexui67jPy/WUSuSEhQKIp/Vudnlo4vj7tvu7sOB3xFdzRKii8s9EFs47WHgjtj+O4CXYwvQjQSqYvuHAjPcvQLYBlwU8L+PSMLU0SzSjJntcveuLexfA3ze3VfFFp/b4O7HmNlmoK+7N8T2f+LuvcysBugft6jawWWtn3f3obHtHwP57v5/gv83E2mbzhRE2sdbud0e++Ju70dze5JCFAoi7XNx3D/nx26/zqeLkl0KvBq7/QJwPTR9oUuPZBUpcqT0DkXkszrHrTQK8Ky7H/xYarGZvUv03f602L7/DtxvZj8Eavh0VdTvArPM7EqiZwTXA6EvOS5yOJpTEElQbE4h4u6bw65FJCi6fCQiIk10piAiIk10piAiIk0UCiIi0kShICIiTRQKIiLSRKEgIiJNFAoiItLk/wMwFW0bFd21HgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x110c3b048>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Model trained.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model = CNN(in_channels = 300, out_channels = 2, batch_size = 10)\n",
    "model.train(train_iter = train_iter, val_iter = val_iter, num_epochs = 5, learning_rate = 1e-3)\n",
    "model.test(test_iter)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
